/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#include "infershape_case_executor.h"
#include <gtest/gtest.h>
#include "base/registry/op_impl_space_registry_v2.h"

#define DO_INFERSHAPE(infershapeContextPara)                                                                           \
    auto contextFaker = gert::InferShapeContextFaker();                                                                \
    /* 1. input/output information */                                                                                  \
    size_t inputNum = infershapeContextPara.inputTensorDesc_.size();                                                   \
    size_t outputNum = infershapeContextPara.outputTensorDesc_.size();                                                 \
    if (infershapeContextPara.inputInstanceNum_.size() != 0 || infershapeContextPara.outputInstanceNum_.size() != 0) { \
        contextFaker.IrInstanceNum(infershapeContextPara.inputInstanceNum_, infershapeContextPara.outputInstanceNum_); \
    } else {                                                                                                           \
        contextFaker.NodeIoNum(inputNum, outputNum);                                                                   \
    }                                                                                                                  \
    std::vector<gert::Tensor*> inputTensors = {};                                                                      \
    std::vector<gert::StorageShape*> outputShapes = {};                                                                \
    std::vector<std::unique_ptr<gert::Tensor>> inputTensorsKeepAlive = {};                                             \
    for (size_t index = 0; index < inputNum; index++) {                                                                \
        contextFaker.NodeInputTd(                                                                                      \
            index, infershapeContextPara.inputTensorDesc_[index].dtype_,                                               \
            infershapeContextPara.inputTensorDesc_[index].format_,                                                     \
            infershapeContextPara.inputTensorDesc_[index].format_);                                                    \
        std::unique_ptr<gert::Tensor> curTensor = std::make_unique<gert::Tensor>(                                      \
            infershapeContextPara.inputTensorDesc_[index].shape_,                                                      \
            gert::StorageFormat(                                                                                       \
                infershapeContextPara.inputTensorDesc_[index].format_,                                                 \
                infershapeContextPara.inputTensorDesc_[index].format_, gert::ExpandDimsType()),                        \
            gert::TensorPlacement::kOnHost, infershapeContextPara.inputTensorDesc_[index].dtype_,                      \
            infershapeContextPara.inputTensorDesc_[index].isConst_ ?                                                   \
                infershapeContextPara.inputTensorDesc_[index].constValue_ :                                            \
                nullptr);                                                                                              \
        inputTensors.push_back(curTensor.get());                                                                       \
        inputTensorsKeepAlive.push_back(std::move(curTensor));                                                         \
    }                                                                                                                  \
    for (size_t index = 0; index < outputNum; index++) {                                                               \
        contextFaker.NodeOutputTd(                                                                                     \
            index, infershapeContextPara.outputTensorDesc_[index].dtype_,                                              \
            infershapeContextPara.outputTensorDesc_[index].format_,                                                    \
            infershapeContextPara.outputTensorDesc_[index].format_);                                                   \
        outputShapes.push_back(&infershapeContextPara.outputTensorDesc_[index].shape_);                                \
    }                                                                                                                  \
    contextFaker.InputTensors(inputTensors).OutputShapes(outputShapes);                                                \
    for (auto& attrInfo : infershapeContextPara.attrs_) {                                                              \
        switch (attrInfo.attr_.type_) {                                                                                \
            case Ops::Math::AnyValue::ValueType::VT_BOOL: {                                                            \
                contextFaker.Attr(attrInfo.attrName_, *reinterpret_cast<bool*>(attrInfo.attr_.valuePtr_.get()));       \
                break;                                                                                                 \
            }                                                                                                          \
            case Ops::Math::AnyValue::ValueType::VT_INT: {                                                             \
                contextFaker.Attr(attrInfo.attrName_, *reinterpret_cast<int64_t*>(attrInfo.attr_.valuePtr_.get()));    \
                break;                                                                                                 \
            }                                                                                                          \
            case Ops::Math::AnyValue::ValueType::VT_FLOAT: {                                                           \
                contextFaker.Attr(attrInfo.attrName_, *reinterpret_cast<float*>(attrInfo.attr_.valuePtr_.get()));      \
                break;                                                                                                 \
            }                                                                                                          \
            case Ops::Math::AnyValue::ValueType::VT_STRING: {                                                          \
                contextFaker.Attr(                                                                                     \
                    attrInfo.attrName_,                                                                                \
                    AscendString(reinterpret_cast<std::string*>(attrInfo.attr_.valuePtr_.get())->c_str()));            \
                break;                                                                                                 \
            }                                                                                                          \
            case Ops::Math::AnyValue::ValueType::VT_LIST_BOOL: {                                                       \
                contextFaker.Attr(                                                                                     \
                    attrInfo.attrName_, *reinterpret_cast<std::vector<bool>*>(attrInfo.attr_.valuePtr_.get()));        \
                break;                                                                                                 \
            }                                                                                                          \
            case Ops::Math::AnyValue::ValueType::VT_LIST_INT: {                                                        \
                contextFaker.Attr(                                                                                     \
                    attrInfo.attrName_, *reinterpret_cast<std::vector<int64_t>*>(attrInfo.attr_.valuePtr_.get()));     \
                break;                                                                                                 \
            }                                                                                                          \
            case Ops::Math::AnyValue::ValueType::VT_LIST_LIST_INT: {                                                   \
                contextFaker.Attr(                                                                                     \
                    attrInfo.attrName_,                                                                                \
                    *reinterpret_cast<std::vector<std::vector<int64_t>>*>(attrInfo.attr_.valuePtr_.get()));            \
                break;                                                                                                 \
            }                                                                                                          \
            case Ops::Math::AnyValue::ValueType::VT_LIST_FLOAT: {                                                      \
                contextFaker.Attr(                                                                                     \
                    attrInfo.attrName_, *reinterpret_cast<std::vector<float>*>(attrInfo.attr_.valuePtr_.get()));       \
                break;                                                                                                 \
            }                                                                                                          \
            default:                                                                                                   \
                std::cout << "[ERROR]" << __FILE__ << ":" << __LINE__ << "The ValueType " << attrInfo.attr_.type_      \
                          << "is not supported!" << std::endl;                                                         \
        }                                                                                                              \
    }                                                                                                                  \
    auto contextHolder = contextFaker.SetOpType(infershapeContextPara.opName_.c_str()).Build();                        \
    /* 2. get infershape func */                                                                                       \
    auto spaceRegistry = gert::DefaultOpImplSpaceRegistryV2::GetInstance().GetSpaceRegistry();                         \
    auto infershapeFunc = spaceRegistry->GetOpImpl(infershapeContextPara.opName_.c_str())->infer_shape;                \
    /* 3. check infershape func */                                                                                     \
    auto infershapeRet = infershapeFunc(contextHolder.GetContext());

static std::vector<int64_t> ToVector(const gert::Shape& shape)
{
    size_t shapeSize = shape.GetDimNum();
    std::vector<int64_t> shapeVec(shapeSize, 0);

    for (size_t i = 0; i < shapeSize; i++) {
        shapeVec[i] = shape.GetDim(i);
    }
    return shapeVec;
}

void ExecuteTestCase(
    gert::InfershapeContextPara& infershapeContextPara, ge::graphStatus expectResult,
    const std::vector<std::vector<int64_t>>& expectOutputShape)
{
    DO_INFERSHAPE(infershapeContextPara);

    // check infershape func
    EXPECT_EQ(infershapeRet, expectResult);
    if (expectResult == ge::GRAPH_FAILED) {
        return;
    }

    // check output shape
    for (int i = 0; i < expectOutputShape.size(); i++) {
        // EXPECT_EQ(ToVector(*contextHolder.GetContext()->GetOutputShape(i)), expectOutputShape[i]);
    }
}
